{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# training Mamba on the TinyHome task.\n",
    "# this is a simplified (and not yet finished) reproduction of the picocrafter experiment done by @fchollet\n",
    "# (you can look up this experiment on his X, he talked about it in Nov 2023)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import time\n",
    "from IPython.display import clear_output\n",
    "\n",
    "from example.tinyhome import TinyHomeEngineV1, print_grid, print_act\n",
    "from example.buffer import ReplayBuffer\n",
    "\n",
    "from mambapy.mamba_lm import MambaLM, MambaLMConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "L = 5\n",
    "num_actions = 5\n",
    "num_obs_type = 4\n",
    "\n",
    "nb_instances = 512\n",
    "steps = 10000\n",
    "\n",
    "envs = TinyHomeEngineV1(B=nb_instances, h=L, w=L)\n",
    "buffer = ReplayBuffer(num_envs=nb_instances, capacity=int(1e6), obs_dim=L*L, act_dim=num_actions)\n",
    "\n",
    "obs = envs.reset()\n",
    "\n",
    "for _ in range(steps):\n",
    "    a = torch.randint(low=0, high=num_actions, size=(nb_instances,))\n",
    "    next_obs, rew = envs.step(a)\n",
    "\n",
    "    buffer.store(obs.view(-1, L*L), a, rew.squeeze(1))\n",
    "    obs = next_obs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/alex/miniconda3/envs/torch23/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "config = MambaLMConfig(d_model=16, n_layers=4, vocab_size=num_actions+num_obs_type, pad_vocab_size_multiple=num_actions+num_obs_type)\n",
    "model = MambaLM(config).to(device)\n",
    "optim = torch.optim.AdamW(model.parameters(), lr=3e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "13664"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum([p.numel() for p in model.parameters()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6.719587326049805\n",
      "0.38883084058761597\n",
      "0.26524049043655396\n",
      "0.22898846864700317\n",
      "0.21574825048446655\n",
      "0.1874855011701584\n",
      "0.1682835817337036\n",
      "0.14247804880142212\n",
      "0.1338169276714325\n",
      "0.10260577499866486\n"
     ]
    }
   ],
   "source": [
    "for i in range(1000):    \n",
    "    B, T = 64, 10\n",
    "    batch = buffer.sample(B, T)\n",
    "\n",
    "    obs = torch.tensor(batch['obs']).long().to(device)\n",
    "    act = torch.tensor(batch['act']).long().to(device)\n",
    "\n",
    "    tokens = torch.cat([obs, torch.zeros(B, T, 1, dtype=torch.int, device='cuda')], dim=2).view(B, 26*T) # (B, 26T)\n",
    "    tokens[:, 25::26] = act+4\n",
    "\n",
    "    input = tokens\n",
    "    output = tokens[:, 1:].reshape(-1)\n",
    "\n",
    "    logits = model(tokens[:, :-1]) # (B, 26T-1, vocab_size)\n",
    "    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), output)\n",
    "\n",
    "    optim.zero_grad()\n",
    "    loss.backward()\n",
    "    optim.step()\n",
    "\n",
    "    if i%100==0:\n",
    "        print(loss.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokens = torch.ones(1, 2, dtype=torch.long).cuda() # (B=1, 2)\n",
    "T = 20\n",
    "for _ in range(26*T-2):\n",
    "  logits = model(tokens)[0, -1]\n",
    "  probs = F.softmax(logits, dim=0)\n",
    "  sampled = torch.multinomial(probs, num_samples=1, replacement=True)\n",
    "  tokens = torch.cat([tokens, sampled.view(1, 1)], dim=1)\n",
    "tokens = tokens.view(T, 26) # (T, 26)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#####\n",
      "#   #\n",
      "#   #\n",
      "#G@ #\n",
      "#####\n",
      "\n",
      "\n",
      "E\n"
     ]
    }
   ],
   "source": [
    "for timestep in tokens:\n",
    "  grid = timestep[:-1].view(1, 5, 5)\n",
    "  a = timestep[-1]-4\n",
    "\n",
    "  clear_output(wait=True)\n",
    "  print_grid(grid)\n",
    "  print_act(a.item())\n",
    "  time.sleep(0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch23",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
